"""
Model estimation module for Demographics & Capital Flows project.

Implements:
1. Baseline Koomen-style CA model with demographic polynomial + EBA controls
2. Pooled GLS with panel-wide AR(1) correction (EBA methodology)
3. Extended models with interest rate channels
4. Diagnostic tests and model comparison
"""

import pandas as pd
import numpy as np
from pathlib import Path
import statsmodels.api as sm
from scipy import linalg

# Project data locations; both live under the same project root.
PROCESSED_DIR = Path("/mnt/c/demographics_capital_flows/multilateral") / "data" / "processed"
OUTPUT_DIR = PROCESSED_DIR.parents[1] / "output"


# ---------------------------------------------------------------------------
# Panel GLS Estimation
# ---------------------------------------------------------------------------

class PanelGLS:
    """
    Pooled GLS estimator with AR(1) error correction.

    Following EBA methodology: no fixed effects, but allow for
    panel-wide serial correlation in errors.

    Model: y_it = X_it β + u_it
    where u_it = ρ u_{i,t-1} + ε_it

    Estimation:
    1. OLS to get initial residuals
    2. Estimate panel-wide ρ from residuals
    3. Prais-Winsten / Cochrane-Orcutt transformation
    4. GLS on transformed data

    Steps 2-4 are repeated until ρ changes by less than ``tol`` or
    ``max_iter`` passes have been made.
    """

    def __init__(self):
        self.beta = None              # slope coefficients (intercept excluded)
        self.se = None                # standard errors of ``beta``
        self.tvalues = None           # beta / se
        self.pvalues = None           # p-values of ``beta`` from the final OLS fit
        self.constant = None          # intercept estimate
        self.se_constant = None       # standard error of the intercept
        self.pvalue_constant = None   # p-value of the intercept
        self.rho = None               # panel-wide AR(1) coefficient used in the final fit
        self.converged = None         # True if the ρ iteration met ``tol``
        self.n_iter = None            # number of ρ estimates computed
        self.resid = None             # untransformed residuals y - [1, X] params
        self.fitted = None            # untransformed fitted values
        self.r_squared = None
        self.r_squared_adj = None
        self.n_obs = None
        self.n_countries = None
        self.feature_names = None
        self.entity_ids = None        # entity ids of the estimation sample
        self.time_ids = None          # time ids of the estimation sample

    def fit(self, y, X, entity_ids, time_ids, max_iter=20, tol=1e-6):
        """
        Estimate pooled GLS with iterative Cochrane-Orcutt.

        Parameters
        ----------
        y : array-like, dependent variable (CA/GDP)
        X : array-like, regressors (including polynomial Z's and controls).
            A 1-D array is treated as a single regressor. An intercept is
            always prepended internally.
        entity_ids : array-like, country identifiers
        time_ids : array-like, year identifiers; two observations of the same
            entity are treated as consecutive only if their ids differ by 1
        max_iter : int, maximum iterations for Cochrane-Orcutt
        tol : float, convergence tolerance for ρ

        Returns
        -------
        self
        """
        y = np.asarray(y, dtype=float)
        X = np.asarray(X, dtype=float)
        if X.ndim == 1:
            X = X[:, np.newaxis]
        entity_ids = np.asarray(entity_ids)
        time_ids = np.asarray(time_ids)

        # Drop observations with NaN in y or in any regressor
        mask = ~(np.isnan(y) | np.any(np.isnan(X), axis=1))
        y, X = y[mask], X[mask]
        entity_ids, time_ids = entity_ids[mask], time_ids[mask]

        self.n_obs = len(y)
        self.n_countries = len(np.unique(entity_ids))
        k = X.shape[1]

        # Step 1: Initial OLS. has_constant='add' guarantees column 0 is the
        # intercept even if a regressor is constant in this sample; the
        # default ('skip') would drop it and shift every coefficient one
        # slot relative to its label.
        X_const = sm.add_constant(X, prepend=True, has_constant='add')
        ols_result = sm.OLS(y, X_const).fit()
        resid = ols_result.resid

        # Iterative Cochrane-Orcutt. Seed with the OLS fit so that immediate
        # convergence (ρ ≈ 0 on the first pass) still leaves a valid result.
        gls_result = ols_result
        rho = 0.0
        converged = False
        n_iter = 0
        for _ in range(max_iter):
            n_iter += 1
            # Step 2: Estimate rho from panel residuals
            rho_new = self._estimate_rho(resid, entity_ids, time_ids)

            if abs(rho_new - rho) < tol:
                converged = True
                break
            rho = rho_new

            # Step 3: Prais-Winsten transformation
            y_t, X_t = self._transform_data(y, X_const, entity_ids, time_ids, rho)

            # Step 4: OLS on transformed data
            gls_result = sm.OLS(y_t, X_t).fit()
            resid = y - X_const @ gls_result.params  # untransformed residuals

        params = gls_result.params
        self.rho = rho
        self.converged = converged
        self.n_iter = n_iter
        self.constant = params[0]
        self.beta = params[1:]  # exclude constant
        self.se_constant = gls_result.bse[0]
        self.se = gls_result.bse[1:]
        self.tvalues = self.beta / self.se
        self.pvalue_constant = gls_result.pvalues[0]
        self.pvalues = gls_result.pvalues[1:]
        self.resid = resid
        self.fitted = X_const @ params
        self.entity_ids = entity_ids
        self.time_ids = time_ids

        # R-squared (from untransformed data)
        ss_res = np.sum(resid ** 2)
        ss_tot = np.sum((y - np.mean(y)) ** 2)
        self.r_squared = 1 - ss_res / ss_tot
        self.r_squared_adj = 1 - (1 - self.r_squared) * (self.n_obs - 1) / (self.n_obs - k - 1)

        return self

    def _estimate_rho(self, resid, entity_ids, time_ids):
        """
        Estimate the panel-wide AR(1) coefficient from residuals.

        Pools Σ e_t e_{t-1} / Σ e_{t-1}² over all within-entity pairs whose
        time ids differ by exactly 1, and clips the result to [-0.99, 0.99].
        Returns 0.0 when there are no such pairs.
        """
        numerator = 0.0
        denominator = 0.0

        for entity in np.unique(entity_ids):
            mask = entity_ids == entity
            e_resid = resid[mask]
            e_times = time_ids[mask]

            # Sort by time
            order = np.argsort(e_times)
            e_resid = e_resid[order]
            e_times = e_times[order]

            # Only use consecutive time periods
            for t in range(1, len(e_times)):
                if e_times[t] - e_times[t - 1] == 1:
                    numerator += e_resid[t] * e_resid[t - 1]
                    denominator += e_resid[t - 1] ** 2

        if denominator > 0:
            return np.clip(numerator / denominator, -0.99, 0.99)
        return 0.0

    def _transform_data(self, y, X, entity_ids, time_ids, rho):
        """
        Apply the Prais-Winsten transformation entity by entity.

        The first observation of each entity, and the first observation after
        any gap in its time ids, is scaled by sqrt(1 - ρ²). Every other
        observation is quasi-differenced: z_t - ρ z_{t-1}. Returns
        transformed copies of ``y`` and ``X``; the inputs are not modified.
        """
        y_t = y.copy()
        X_t = X.copy()

        sqrt_factor = np.sqrt(1 - rho ** 2)

        for entity in np.unique(entity_ids):
            mask = entity_ids == entity
            indices = np.where(mask)[0]
            e_times = time_ids[mask]
            order = np.argsort(e_times)
            sorted_indices = indices[order]

            # First observation: Prais-Winsten adjustment
            i0 = sorted_indices[0]
            y_t[i0] = sqrt_factor * y[i0]
            X_t[i0] = sqrt_factor * X[i0]

            # Subsequent observations: quasi-difference
            for t in range(1, len(sorted_indices)):
                it = sorted_indices[t]
                it_prev = sorted_indices[t - 1]

                if time_ids[it] - time_ids[it_prev] == 1:
                    y_t[it] = y[it] - rho * y[it_prev]
                    X_t[it] = X[it] - rho * X[it_prev]
                else:
                    # Gap in time series — treat as new start
                    y_t[it] = sqrt_factor * y[it]
                    X_t[it] = sqrt_factor * X[it]

        return y_t, X_t

    @staticmethod
    def _significance_stars(p):
        """Return '***' / '**' / '*' for p < 0.01 / 0.05 / 0.1, else ''."""
        return '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.1 else ''

    def _resolve_names(self, feature_names=None):
        """
        Pick regressor labels: explicit argument, else stored names, else X1..Xk.

        Empty sequences count as "not given". Works with lists, tuples and
        numpy arrays, unlike a plain ``or`` chain, which raises on arrays.
        """
        for candidate in (feature_names, getattr(self, 'feature_names', None)):
            if candidate is not None and len(candidate) > 0:
                return list(candidate)
        return [f'X{i+1}' for i in range(len(self.beta))]

    def _constant_pvalue(self):
        """Intercept p-value, or NaN if unavailable (e.g. not yet fitted)."""
        p = getattr(self, 'pvalue_constant', None)
        return np.nan if p is None else p

    def summary(self, feature_names=None):
        """Print regression summary. Returns self."""
        if feature_names is not None:
            self.feature_names = feature_names

        names = self._resolve_names()
        p_const = self._constant_pvalue()

        print("=" * 70)
        print("Pooled GLS with AR(1) Correction")
        print("=" * 70)
        print(f"N obs: {self.n_obs:,}   N countries: {self.n_countries}")
        print(f"R²: {self.r_squared:.4f}   Adj R²: {self.r_squared_adj:.4f}")
        print(f"ρ (AR1): {self.rho:.4f}")
        print("-" * 70)
        print(f"{'Variable':<25} {'Coef':>10} {'Std Err':>10} {'t-stat':>10} {'p-val':>8}")
        print("-" * 70)
        print(f"{'constant':<25} {self.constant:>10.4f} {self.se_constant:>10.4f} "
              f"{self.constant/self.se_constant:>10.2f} {p_const:>8.4f} "
              f"{self._significance_stars(p_const)}")

        for name, b, se, t, p in zip(names, self.beta, self.se, self.tvalues, self.pvalues):
            sig = self._significance_stars(p)
            print(f"{name:<25} {b:>10.4f} {se:>10.4f} {t:>10.2f} {p:>8.4f} {sig}")
        print("=" * 70)

        return self

    def to_dataframe(self, feature_names=None):
        """Return results (intercept first, then regressors) as a DataFrame."""
        names = self._resolve_names(feature_names)
        return pd.DataFrame({
            'variable': ['constant'] + names,
            'coefficient': np.concatenate([[self.constant], self.beta]),
            'std_error': np.concatenate([[self.se_constant], self.se]),
            't_statistic': np.concatenate([[self.constant / self.se_constant], self.tvalues]),
            'p_value': np.concatenate([[self._constant_pvalue()], self.pvalues]),
        })


# ---------------------------------------------------------------------------
# Model specifications
# ---------------------------------------------------------------------------

def estimate_baseline_model(panel_df):
    """
    Estimate the baseline Koomen-style current-account model:

        CA_it/GDP_it = γ₁Z₁ + γ₂Z₂ + γ₃Z₃ + β'X_it + u_it

    X_it is a subset of the EBA controls picked stepwise. Core controls
    (fiscal balance, KAOPEN, expected growth) are kept whenever they have more
    than 200 non-missing values. Secondary and then tertiary controls are
    tried one at a time, in priority order. Each is kept only if it has more
    than 200 non-missing values and the complete-case sample stays at or
    above 70% of the current sample size.

    Returns
    -------
    (PanelGLS, DataFrame)
        The fitted model and the estimation sample, with ``resid_baseline``
        and ``fitted_baseline`` columns attached.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("BASELINE MODEL: Demographics + EBA Controls")
    print(banner)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    sample = panel_df.dropna(subset=['ca_gdp'] + demo_vars).copy()

    def usable(col):
        # Column exists and has enough coverage to be worth trying.
        return col in sample.columns and sample[col].notna().sum() > 200

    def complete_cases(cols):
        return sample.dropna(subset=['ca_gdp'] + demo_vars + cols).shape[0]

    controls = [col for col in ('fiscal_bal_gdp', 'kaopen', 'expected_growth') if usable(col)]
    current_n = complete_cases(controls)

    # Secondary controls first, then tertiary (sparser) ones; both tiers use
    # the same 70%-of-current-sample rule.
    optional_controls = [
        'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp',   # secondary
        'output_gap', 'life_expectancy',                  # tertiary
    ]
    for col in optional_controls:
        if not usable(col):
            continue
        candidate_n = complete_cases(controls + [col])
        if candidate_n >= 0.7 * current_n:
            controls.append(col)
            current_n = candidate_n

    regressors = demo_vars + controls
    print(f"  Using controls: {controls}")

    est = sample.dropna(subset=['ca_gdp'] + regressors)
    print(f"  Sample: {est['iso3'].nunique()} countries, {len(est):,} obs, "
          f"{est['year'].min()}-{est['year'].max()}")

    model = PanelGLS()
    model.fit(est['ca_gdp'].values, est[regressors].values,
              est['iso3'].values, est['year'].values)
    model.summary(feature_names=regressors)

    # Attach residuals/fitted values; rows line up because every regressor
    # was already complete in ``est``.
    est['resid_baseline'] = model.resid
    est['fitted_baseline'] = model.fitted

    return model, est


def estimate_extended_model(panel_df):
    """
    Extended model with interest-rate channel and interaction terms:

        CA_it/GDP_it = γ₁Z₁ + γ₂Z₂ + γ₃Z₃ + β'X_it + δ(Z×KAOPEN) + u_it

    Regressor selection:
    - EBA controls: core + secondary, using the baseline stepwise rule
      (more than 200 non-missing values, sample kept at >= 70%).
    - Rate variables: real 10y bond and 3m short-rate differentials, plus
      ``log_lending_rate`` (or ``lending_rate`` if the log is unavailable).
      Each needs more than 50 values and must keep >= 50% of the control
      sample.
    - Pension variables: ``pension_spending_gdp`` / ``pension_coverage``
      (more than 100 values, >= 50% of the control sample). When pension
      spending is kept, Z×pension interactions are added.
    - Precomputed ``Z_k_x_kaopen`` columns with more than 200 values.

    Returns (None, sample) if fewer than 100 complete observations remain;
    otherwise (PanelGLS, sample) with ``resid_extended`` / ``fitted_extended``
    columns attached.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("EXTENDED MODEL: Demographics + Interactions")
    print(banner)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    sample = panel_df.dropna(subset=['ca_gdp'] + demo_vars).copy()

    def n_nonmissing(col):
        # Missing columns count as zero coverage.
        return sample[col].notna().sum() if col in sample.columns else 0

    def complete_cases(cols):
        return sample.dropna(subset=['ca_gdp'] + demo_vars + cols).shape[0]

    # EBA controls: same stepwise rule as the baseline (core + secondary).
    controls = [c for c in ('fiscal_bal_gdp', 'kaopen', 'expected_growth')
                if n_nonmissing(c) > 200]
    base_n = complete_cases(controls)
    for c in ('nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp'):
        if n_nonmissing(c) > 200:
            trial_n = complete_cases(controls + [c])
            if trial_n >= 0.7 * base_n:
                controls.append(c)
                base_n = trial_n

    # Rate variables; at most one lending-rate measure, log form preferred.
    rate_candidates = ['real_bond_10y_diff', 'real_short_3m_diff']
    lending_measure = next(
        (c for c in ('log_lending_rate', 'lending_rate') if n_nonmissing(c) > 50),
        None,
    )
    if lending_measure is not None:
        rate_candidates.append(lending_measure)

    rate_vars = []
    for rv in rate_candidates:
        if (n_nonmissing(rv) > 50
                and complete_cases(controls + rate_vars + [rv]) >= 0.5 * base_n):
            rate_vars.append(rv)

    # Pension variables (if available)
    pension_vars = []
    for pv in ('pension_spending_gdp', 'pension_coverage'):
        if (n_nonmissing(pv) > 100
                and complete_cases(controls + rate_vars + pension_vars + [pv]) >= 0.5 * base_n):
            pension_vars.append(pv)

    # Z × pension-spending interactions, only when pension spending is in.
    pension_interaction_vars = []
    if 'pension_spending_gdp' in pension_vars:
        for zv in demo_vars:
            col_name = f'{zv}_x_pension'
            sample[col_name] = sample[zv] * sample['pension_spending_gdp']
            pension_interaction_vars.append(col_name)

    # Z × KAOPEN interactions (precomputed upstream)
    interaction_vars = [iv for iv in ('Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen')
                        if n_nonmissing(iv) > 200]

    all_vars = (demo_vars + controls + rate_vars + pension_vars
                + interaction_vars + pension_interaction_vars)
    est = sample.dropna(subset=['ca_gdp'] + all_vars)

    if len(est) < 100:
        print("  Insufficient observations for extended model")
        return None, est

    print(f"  Variables: {all_vars}")
    print(f"  Sample: {est['iso3'].nunique()} countries, {len(est):,} obs")

    model = PanelGLS()
    model.fit(est['ca_gdp'].values, est[all_vars].values,
              est['iso3'].values, est['year'].values)
    model.summary(feature_names=all_vars)

    est['resid_extended'] = model.resid
    est['fitted_extended'] = model.fitted

    return model, est


def estimate_pension_model(panel_df):
    """
    Pension model: Demographics + EBA controls + pension spending + Z×pension interactions.

    Restricted to rows with pension spending data (mostly OECD countries).
    Tests whether pension system generosity mediates the demographic-CA
    relationship.

    Control selection follows the baseline stepwise rule, with a lower
    coverage threshold (more than 50 non-missing values) suited to the
    smaller sample. Z×KAOPEN interactions are included when they have more
    than 50 values.

    Returns (None, sample) if fewer than 100 observations are available,
    either before or after dropping missing regressors. Otherwise returns
    (PanelGLS, sample) with ``resid_pension`` / ``fitted_pension`` attached.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("PENSION MODEL: Demographics + EBA + Pension Interactions")
    print(rule)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    sample = panel_df.dropna(subset=['ca_gdp'] + demo_vars + ['pension_spending_gdp']).copy()

    if len(sample) < 100:
        print(f"  Insufficient observations with pension data ({len(sample)})")
        return None, sample

    def has_data(col):
        return col in sample.columns and sample[col].notna().sum() > 50

    def complete_n(cols):
        return sample.dropna(subset=['ca_gdp'] + demo_vars + cols).shape[0]

    # EBA controls (same stepwise logic as the baseline, lower threshold)
    controls = [c for c in ('fiscal_bal_gdp', 'kaopen', 'expected_growth') if has_data(c)]
    base_n = complete_n(controls)
    for c in ('nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp'):
        if has_data(c):
            trial_n = complete_n(controls + [c])
            if trial_n >= 0.7 * base_n:
                controls.append(c)
                base_n = trial_n

    # Pension level control plus its Z interactions
    pension_vars = ['pension_spending_gdp']
    pension_interactions = [f'{zv}_x_pension' for zv in demo_vars]
    for zv, col_name in zip(demo_vars, pension_interactions):
        sample[col_name] = sample[zv] * sample['pension_spending_gdp']

    # Z × KAOPEN interactions (if available)
    kaopen_interactions = [iv for iv in ('Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen')
                           if has_data(iv)]

    all_vars = demo_vars + controls + pension_vars + pension_interactions + kaopen_interactions
    est = sample.dropna(subset=['ca_gdp'] + all_vars)

    if len(est) < 100:
        print(f"  Insufficient observations after dropping NAs ({len(est)})")
        return None, est

    print(f"  Variables: {all_vars}")
    print(f"  Sample: {est['iso3'].nunique()} countries, {len(est):,} obs, "
          f"{est['year'].min()}-{est['year'].max()}")

    model = PanelGLS()
    model.fit(est['ca_gdp'].values, est[all_vars].values,
              est['iso3'].values, est['year'].values)
    model.summary(feature_names=all_vars)

    est['resid_pension'] = model.resid
    est['fitted_pension'] = model.fitted

    return model, est


def estimate_demographics_only(panel_df):
    """
    Pure demographic model (Z_1, Z_2, Z_3 only, no controls).

    Useful as a comparison point that shows the standalone explanatory power
    of demographics. Returns (PanelGLS, sample), where sample is the rows of
    ``panel_df`` with complete ``ca_gdp`` and Z data.
    """
    divider = "=" * 70
    print("\n" + divider)
    print("DEMOGRAPHICS-ONLY MODEL")
    print(divider)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    sample = panel_df.dropna(subset=['ca_gdp'] + demo_vars).copy()

    model = PanelGLS()
    model.fit(sample['ca_gdp'].values, sample[demo_vars].values,
              sample['iso3'].values, sample['year'].values)
    model.summary(feature_names=demo_vars)

    return model, sample


# ---------------------------------------------------------------------------
# Model comparison and diagnostics
# ---------------------------------------------------------------------------

def compare_models(models_dict):
    """
    Compare multiple model specifications side by side.

    Parameters
    ----------
    models_dict : dict of {name: (model, df)} tuples
        Entries whose model is None (e.g. skipped specifications) are left out.

    Returns
    -------
    DataFrame
        One row per model: sample size, country count, R², adjusted R²,
        AR(1) ρ and the RMSE of the untransformed residuals. It is also
        printed.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("MODEL COMPARISON")
    print(rule)

    def _stats_row(label, fitted):
        return {
            'Model': label,
            'N obs': fitted.n_obs,
            'N countries': fitted.n_countries,
            'R²': fitted.r_squared,
            'Adj R²': fitted.r_squared_adj,
            'ρ': fitted.rho,
            'RMSE': np.sqrt(np.mean(fitted.resid ** 2)),
        }

    comparison = pd.DataFrame([
        _stats_row(label, fitted)
        for label, (fitted, _sample) in models_dict.items()
        if fitted is not None
    ])
    print(comparison.to_string(index=False, float_format='%.4f'))
    return comparison


def country_residual_analysis(model, df, n_top=20):
    """
    Summarise model residuals by country to spot over/under-performers.

    ``df`` must be the model's estimation sample, so that its rows line up
    with ``model.resid``; it is not modified. Prints the ``n_top`` countries
    with the largest absolute mean residual. Returns the full per-country
    table (mean/std/count of residuals, mean CA/GDP, abs mean residual),
    sorted by absolute mean residual, descending.
    """
    work = df.assign(resid=model.resid)

    per_country = (
        work.groupby('iso3')
        .agg(
            mean_resid=('resid', 'mean'),
            std_resid=('resid', 'std'),
            n_obs=('resid', 'count'),
            mean_ca=('ca_gdp', 'mean'),
        )
        .reset_index()
    )
    per_country['abs_resid'] = per_country['mean_resid'].abs()
    per_country = per_country.sort_values('abs_resid', ascending=False)

    print(f"\nTop {n_top} countries by absolute mean residual:")
    print(per_country.head(n_top).to_string(index=False, float_format='%.3f'))

    return per_country


def extract_demographic_contribution(model, panel_df, feature_names, include_interactions=True):
    """
    Extract the demographic contribution to CA/GDP for each country-year.

        demo_contribution_it = Σ_j γ_j · W_j,it

    The sum runs over the regressors selected from ``feature_names``:

    - ``include_interactions=True`` (default, the historical behaviour):
      every regressor whose name starts with ``'Z_'``. This includes
      interaction terms such as ``Z_1_x_kaopen`` or ``Z_1_x_pension``.
    - ``include_interactions=False``: only the pure polynomial terms
      ``Z_<digits>``, i.e. γ₁Z₁_it + γ₂Z₂_it + γ₃Z₃_it.

    Parameters
    ----------
    model : fitted PanelGLS; ``model.beta`` is ordered like ``feature_names``
    panel_df : DataFrame with ``iso3``, ``year`` and every selected column
    feature_names : sequence of regressor names used to fit ``model``
    include_interactions : bool, see above

    Returns
    -------
    DataFrame with ``iso3``, ``year``, the selected columns and
    ``demo_contribution``.
    """
    def _selected(name):
        if not name.startswith('Z_'):
            return False
        # Pure polynomial terms are exactly 'Z_' followed by digits.
        return include_interactions or name[2:].isdigit()

    z_indices = [i for i, name in enumerate(feature_names) if _selected(name)]
    z_names = [feature_names[i] for i in z_indices]
    z_betas = np.asarray(model.beta, dtype=float)[z_indices]

    df = panel_df[['iso3', 'year'] + z_names].copy()
    df['demo_contribution'] = sum(
        beta * df[name] for beta, name in zip(z_betas, z_names)
    )

    return df


def estimate_nonlinearity_tests(panel_df):
    """
    Test for nonlinear effects in NFA/GDP and life expectancy.

    Estimates up to four variants of the baseline model. Each variant runs
    only when its extra columns exist in ``panel_df`` and the corresponding
    linear term survived control selection:
    1. Baseline + NFA squared (concavity test)
    2. Baseline with NFA split into positive/negative (creditor-debtor asymmetry)
    3. Baseline + life expectancy squared (convexity test)
    4. Baseline + both NFA squared and life expectancy squared (joint test)

    Returns
    -------
    results : dict of {name: (model, df)}
        Fitted PanelGLS and its estimation sample for each variant that ran.
    summary : DataFrame
        One row per coefficient of interest. The t-statistic and two-sided
        p-value come straight from the fitted model.
    """
    print("\n" + "=" * 70)
    print("NONLINEARITY TESTS")
    print("=" * 70)

    df = panel_df.dropna(subset=['ca_gdp', 'Z_1', 'Z_2', 'Z_3']).copy()
    demo_vars = ['Z_1', 'Z_2', 'Z_3']

    # Build baseline control set (same stepwise as estimate_baseline_model)
    core_controls = ['fiscal_bal_gdp', 'kaopen', 'expected_growth']
    secondary_controls = ['nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp']
    tertiary_controls = ['output_gap', 'life_expectancy']

    controls = [c for c in core_controls if c in df.columns and df[c].notna().sum() > 200]
    base_n = df.dropna(subset=['ca_gdp'] + demo_vars + controls).shape[0]
    for c in secondary_controls + tertiary_controls:
        if c in df.columns and df[c].notna().sum() > 200:
            test_n = df.dropna(subset=['ca_gdp'] + demo_vars + controls + [c]).shape[0]
            if test_n >= 0.7 * base_n:
                controls.append(c)
                base_n = test_n

    results = {}
    summary_rows = []

    def fit_variant(test_vars, header):
        """Fit PanelGLS on the complete cases for ``test_vars`` and print it."""
        tdf = df.dropna(subset=['ca_gdp'] + test_vars).copy()
        print(f"\n  {header} ({len(tdf):,} obs, {tdf['iso3'].nunique()} countries)")
        model = PanelGLS()
        model.fit(tdf['ca_gdp'].values, tdf[test_vars].values,
                  tdf['iso3'].values, tdf['year'].values)
        model.summary(feature_names=test_vars)
        return model, tdf

    def coef_row(test, label, model, idx):
        """
        Build a summary row for regressor ``idx`` of ``model``.

        The model's stored t-test (t distribution, GLS residual d.o.f.) is
        used. This replaces the previous hand-rolled ``2 * (1 - cdf)``, which
        loses precision for small p-values and relied on ``scipy.stats``
        having been imported elsewhere.
        """
        return {
            'Test': test,
            'Variable': label,
            'Coefficient': model.beta[idx],
            'Std Error': model.se[idx],
            't-stat': model.tvalues[idx],
            'p-value': model.pvalues[idx],
            'R²': model.r_squared,
            'N': model.n_obs,
        }

    # --- Test 1: NFA squared ---
    if 'nfa_gdp_lag_sq' in df.columns and 'nfa_gdp_lag' in controls:
        test_vars = demo_vars + controls + ['nfa_gdp_lag_sq']
        model, tdf = fit_variant(test_vars, "Test 1: Baseline + NFA²")
        results['Baseline + NFA²'] = (model, tdf)
        summary_rows.append(coef_row('NFA/GDP squared', 'nfa_gdp_lag_sq',
                                     model, test_vars.index('nfa_gdp_lag_sq')))

    # --- Test 2: Creditor/Debtor split (replaces linear NFA) ---
    if 'nfa_positive' in df.columns and 'nfa_negative' in df.columns and 'nfa_gdp_lag' in controls:
        split_controls = [c for c in controls if c != 'nfa_gdp_lag']
        test_vars = demo_vars + split_controls + ['nfa_positive', 'nfa_negative']
        model, tdf = fit_variant(test_vars, "Test 2: Creditor/Debtor NFA split")
        results['Creditor/Debtor Split'] = (model, tdf)
        for label, var_name in [('NFA (creditor, >0)', 'nfa_positive'),
                                ('NFA (debtor, <0)', 'nfa_negative')]:
            summary_rows.append(coef_row('Creditor/debtor split', label,
                                         model, test_vars.index(var_name)))

    # --- Test 3: Life expectancy squared ---
    if 'life_expectancy_sq' in df.columns and 'life_expectancy' in controls:
        test_vars = demo_vars + controls + ['life_expectancy_sq']
        model, tdf = fit_variant(test_vars, "Test 3: Baseline + life_expectancy²")
        results['Baseline + LE²'] = (model, tdf)
        summary_rows.append(coef_row('Life expectancy squared', 'life_expectancy_sq',
                                     model, test_vars.index('life_expectancy_sq')))
        # Also report the linear term (which may change)
        summary_rows.append(coef_row('Life expectancy squared', 'life_expectancy (linear)',
                                     model, test_vars.index('life_expectancy')))

    # --- Test 4: Joint (both squared terms) ---
    if ('nfa_gdp_lag_sq' in df.columns and 'life_expectancy_sq' in df.columns
            and 'nfa_gdp_lag' in controls and 'life_expectancy' in controls):
        test_vars = demo_vars + controls + ['nfa_gdp_lag_sq', 'life_expectancy_sq']
        model, tdf = fit_variant(test_vars, "Test 4: Baseline + NFA² + LE²")
        results['Baseline + NFA² + LE²'] = (model, tdf)
        for var_name in ['nfa_gdp_lag_sq', 'life_expectancy_sq']:
            summary_rows.append(coef_row('Joint (NFA² + LE²)', var_name,
                                         model, test_vars.index(var_name)))

    # Summary
    summary = pd.DataFrame(summary_rows)
    if len(summary) > 0:
        print("\n" + "=" * 70)
        print("NONLINEARITY TEST SUMMARY")
        print("=" * 70)
        print(summary.to_string(index=False, float_format='%.4f'))

    return results, summary


def _rate_test_row(test_name, model, feature_names, var_name):
    """Return one rate-channel summary row for ``var_name`` in a fitted PanelGLS.

    ``model.beta`` / ``model.se`` are indexed by the position of ``var_name``
    in ``feature_names``. The two-sided p-value uses a t distribution with
    ``model.n_obs - len(feature_names) - 1`` degrees of freedom, the same
    convention used by the other tests in this module.
    """
    from scipy import stats

    idx = feature_names.index(var_name)
    coef = model.beta[idx]
    se = model.se[idx]
    t_stat = coef / se
    dof = model.n_obs - len(feature_names) - 1
    # t.sf(x) equals 1 - t.cdf(x) but stays accurate for large |t|, where
    # 1 - cdf rounds to exactly 0.
    p_val = 2 * stats.t.sf(abs(t_stat), dof)
    return {
        'Test': test_name,
        'Variable': var_name,
        'Coefficient': coef,
        'Std_Error': se,
        't_stat': t_stat,
        'p_value': p_val,
        'R_squared': model.r_squared,
        'N_obs': model.n_obs,
        'N_countries': model.n_countries,
    }


def _fit_rate_gls(tdf, y_col, x_vars):
    """Fit PanelGLS of ``tdf[y_col]`` on ``tdf[x_vars]`` (entity=iso3, time=year) and print its summary."""
    model = PanelGLS()
    model.fit(tdf[y_col].values, tdf[x_vars].values,
              tdf['iso3'].values, tdf['year'].values)
    model.summary(feature_names=x_vars)
    return model


def estimate_rate_channel_tests(panel_df):
    """
    Test alternative interest rate measures in the extended model.

    The single-equation specifications add one rate variable to the extended
    model (Z polynomial + stepwise controls + available Z x kaopen
    interactions) and are estimated with PanelGLS:

    - Model 3b: real_bond_10y_diff as the rate variable
    - Model 3c: term_spread as the rate variable
    - Two-stage Carvalho channel test: Z -> real_bond_10y_diff (stage 1),
      fitted yield -> CA/GDP (stage 2); kaopen is excluded from both stages
    - Model 3 (log_lending_rate) re-run on countries with bond-yield data
    - Model 3d: real_short_3m_diff as the rate variable
    - Carry-trade variables (fx_hedged_vs_jpn/usa, carry_vs_jpn/usa),
      one per regression

    A specification is skipped when its variable is missing, has 100 or fewer
    non-null values, or leaves fewer than 100 complete observations.

    Parameters
    ----------
    panel_df : pd.DataFrame
        Country-year panel with at least 'iso3', 'year', 'ca_gdp' and
        'Z_1'..'Z_3'.

    Returns
    -------
    results : dict
        Label -> (PanelGLS, estimation sample) for single-equation models;
        'Two-Stage (Bond Yields)' -> {'stage1': PanelGLS, 'stage2': PanelGLS}.
    summary : pd.DataFrame
        One row per tested coefficient (Test, Variable, Coefficient,
        Std_Error, t_stat, p_value, R_squared, N_obs, N_countries);
        empty when no specification could be estimated.
    """
    print("\n" + "=" * 70)
    print("RATE CHANNEL TESTS (Alternative Interest Rate Measures)")
    print("=" * 70)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    df = panel_df.dropna(subset=['ca_gdp'] + demo_vars).copy()

    def has_data(col, min_obs):
        """True if ``col`` exists in ``df`` with more than ``min_obs`` non-null values."""
        return col in df.columns and df[col].notna().sum() > min_obs

    # Build control set (same stepwise as extended model): core controls
    # enter when available; each secondary control enters only if it keeps
    # at least 70% of the current complete-case sample.
    core_controls = ['fiscal_bal_gdp', 'kaopen', 'expected_growth']
    secondary_controls = ['nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp']
    controls = [c for c in core_controls if has_data(c, 200)]
    base_n = len(df.dropna(subset=['ca_gdp'] + demo_vars + controls))
    for c in secondary_controls:
        if not has_data(c, 200):
            continue
        test_n = len(df.dropna(subset=['ca_gdp'] + demo_vars + controls + [c]))
        if test_n >= 0.7 * base_n:
            controls.append(c)
            base_n = test_n

    interaction_vars = [iv for iv in ['Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']
                        if has_data(iv, 200)]
    # Controls used by both stages of the two-stage test (kaopen excluded).
    no_kaopen_controls = [c for c in controls if c != 'kaopen']

    results = {}
    summary_rows = []

    # --- Model 3b: real_bond_10y_diff ---
    if has_data('real_bond_10y_diff', 100):
        rate_var = 'real_bond_10y_diff'
        all_vars = demo_vars + controls + [rate_var] + interaction_vars
        tdf = df.dropna(subset=['ca_gdp'] + all_vars).copy()

        if len(tdf) >= 100:
            print("\n  Model 3b: Extended + real bond yield differential")
            print(f"    Sample: {tdf['iso3'].nunique()} countries, {len(tdf):,} obs, "
                  f"{tdf['year'].min()}-{tdf['year'].max()}")
            model = _fit_rate_gls(tdf, 'ca_gdp', all_vars)
            results['Model 3b (Bond Yields)'] = (model, tdf)
            summary_rows.append(_rate_test_row(
                'Model 3b (real_bond_10y_diff)', model, all_vars, rate_var))
        else:
            print(f"  Model 3b skipped: insufficient obs ({len(tdf)}) with bond yields")

    # --- Model 3c: term_spread ---
    if has_data('term_spread', 100):
        rate_var = 'term_spread'
        all_vars = demo_vars + controls + [rate_var] + interaction_vars
        tdf = df.dropna(subset=['ca_gdp'] + all_vars).copy()

        if len(tdf) >= 100:
            print("\n  Model 3c: Extended + term spread")
            print(f"    Sample: {tdf['iso3'].nunique()} countries, {len(tdf):,} obs, "
                  f"{tdf['year'].min()}-{tdf['year'].max()}")
            model = _fit_rate_gls(tdf, 'ca_gdp', all_vars)
            results['Model 3c (Term Spread)'] = (model, tdf)
            summary_rows.append(_rate_test_row(
                'Model 3c (term_spread)', model, all_vars, rate_var))
        else:
            print(f"  Model 3c skipped: insufficient obs ({len(tdf)}) with term spread")

    # --- Two-stage Carvalho channel test with bond yields ---
    if 'real_bond_10y_diff' in df.columns:
        s1_vars = demo_vars + no_kaopen_controls
        tdf = df.dropna(subset=['real_bond_10y_diff', 'ca_gdp'] + s1_vars).copy()

        if len(tdf) >= 100:
            print("\n  Two-Stage Carvalho Test (bond yields)")
            print(f"    Sample: {tdf['iso3'].nunique()} countries, {len(tdf):,} obs")

            # Stage 1: Demographics -> real bond yield differential
            print("\n    Stage 1: Z → real_bond_10y_diff")
            s1_model = _fit_rate_gls(tdf, 'real_bond_10y_diff', s1_vars)
            summary_rows.extend(
                _rate_test_row('Two-stage S1 (Z → bond yield)', s1_model, s1_vars, zv)
                for zv in demo_vars
            )

            # Stage 2: fitted bond yields -> CA/GDP. The Z's are left out of
            # stage 2, so the fitted yield carries the demographic variation.
            # NOTE(review): nothing here adjusts the stage-2 standard errors
            # for fitted_bond_yield being a generated regressor.
            print("\n    Stage 2: fitted_bond_yield → CA/GDP")
            tdf['fitted_bond_yield'] = s1_model.fitted
            s2_vars = no_kaopen_controls + ['fitted_bond_yield']
            tdf2 = tdf.dropna(subset=['ca_gdp'] + s2_vars)
            s2_model = _fit_rate_gls(tdf2, 'ca_gdp', s2_vars)
            summary_rows.append(_rate_test_row(
                'Two-stage S2 (fitted yield → CA)', s2_model, s2_vars, 'fitted_bond_yield'))

            results['Two-Stage (Bond Yields)'] = {
                'stage1': s1_model,
                'stage2': s2_model,
            }
        else:
            print(f"  Two-stage test skipped: insufficient obs ({len(tdf)}) with bond yields")

    # --- Also re-run Model 3 (lending rates) on the bond yield sample for comparison ---
    # NOTE(review): the sample is restricted to countries that have any
    # bond-yield data, not to the exact country-years used by Model 3b.
    if has_data('log_lending_rate', 100) and 'real_bond_10y_diff' in df.columns:
        bond_countries = df.loc[df['real_bond_10y_diff'].notna(), 'iso3'].unique()
        rate_var = 'log_lending_rate'
        all_vars = demo_vars + controls + [rate_var] + interaction_vars
        comp_df = df[df['iso3'].isin(bond_countries)]
        tdf = comp_df.dropna(subset=['ca_gdp'] + all_vars).copy()

        if len(tdf) >= 100:
            print("\n  Model 3 (lending rate) on bond-yield sample for comparison")
            print(f"    Sample: {tdf['iso3'].nunique()} countries, {len(tdf):,} obs")
            model = _fit_rate_gls(tdf, 'ca_gdp', all_vars)
            results['Model 3 (Lending, bond sample)'] = (model, tdf)
            summary_rows.append(_rate_test_row(
                'Model 3 (log_lending on bond sample)', model, all_vars, rate_var))

    # --- Model 3d: real_short_3m_diff ---
    if has_data('real_short_3m_diff', 100):
        rate_var = 'real_short_3m_diff'
        all_vars = demo_vars + controls + [rate_var] + interaction_vars
        tdf = df.dropna(subset=['ca_gdp'] + all_vars).copy()

        if len(tdf) >= 100:
            print("\n  Model 3d: Extended + real short rate differential")
            print(f"    Sample: {tdf['iso3'].nunique()} countries, {len(tdf):,} obs")
            model = _fit_rate_gls(tdf, 'ca_gdp', all_vars)
            results['Model 3d (Short Rates)'] = (model, tdf)
            summary_rows.append(_rate_test_row(
                'Model 3d (real_short_3m_diff)', model, all_vars, rate_var))

    # --- Carry trade variables ---
    for carry_var in ['fx_hedged_vs_jpn', 'fx_hedged_vs_usa', 'carry_vs_jpn', 'carry_vs_usa']:
        if not has_data(carry_var, 100):
            continue
        all_vars = demo_vars + controls + [carry_var] + interaction_vars
        tdf = df.dropna(subset=['ca_gdp'] + all_vars).copy()

        if len(tdf) >= 100:
            print(f"\n  Model with {carry_var}")
            print(f"    Sample: {tdf['iso3'].nunique()} countries, {len(tdf):,} obs")
            model = _fit_rate_gls(tdf, 'ca_gdp', all_vars)
            results[f'Carry ({carry_var})'] = (model, tdf)
            summary_rows.append(_rate_test_row(
                f'Carry ({carry_var})', model, all_vars, carry_var))

    # Summary
    summary = pd.DataFrame(summary_rows)
    if len(summary) > 0:
        print("\n" + "=" * 70)
        print("RATE CHANNEL TEST SUMMARY")
        print("=" * 70)
        print(summary.to_string(index=False, float_format='%.4f'))

    return results, summary


def project_rate_channel(panel_df, polys_df, rate_to_ca=0.127, focus=None,
                         proj_years=None):
    """
    Project demographic pressure on bond yield differentials through 2060.

    Re-estimates a Stage-1 PanelGLS regression of real_bond_10y_diff on
    Z_1..Z_3 plus a fixed control set (fiscal_bal_gdp, expected_growth,
    nfa_gdp_lag, log_rel_opw; each kept only if it has more than 200 non-null
    values). It then applies the Z coefficients to the demographic
    polynomials in ``polys_df``. The control set differs from the stepwise
    set used by ``estimate_rate_channel_tests``, so the coefficients need not
    match that function's Stage 1 exactly.

    Parameters
    ----------
    panel_df : pd.DataFrame
        Estimation panel with 'iso3', 'year', 'ca_gdp', 'Z_1'..'Z_3' and
        'real_bond_10y_diff'.
    polys_df : pd.DataFrame
        Projection data with 'iso3', 'year' and 'Z_1'..'Z_3'. If an
        (iso3, year) pair appears more than once, its first row is used.
    rate_to_ca : float, default 0.127
        Coefficient translating the yield effect into a CA/GDP effect. The
        default is the Model 3b estimate; update it if Model 3b changes.
    focus : list of str, optional
        ISO3 codes to project (default: 15 large economies).
    proj_years : iterable of int, optional
        Years to project (default: 2000, 2005, ..., 2060).

    Returns
    -------
    pd.DataFrame or None
        Columns iso3, year, demo_rate_effect, ca_rate_effect. When non-empty,
        the frame is also written to OUTPUT_DIR/tables/rate_channel_projections.csv.
        Returns None if real_bond_10y_diff is missing or there are fewer than
        100 Stage-1 observations.
    """
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    if focus is None:
        focus = ['JPN', 'DEU', 'USA', 'KOR', 'CHN', 'IND', 'IDN', 'NGA',
                 'BRA', 'GBR', 'AUS', 'FRA', 'ZAF', 'SAU', 'MEX']
    # Materialize so a one-shot iterable is not exhausted after the first country.
    proj_years = list(range(2000, 2065, 5) if proj_years is None else proj_years)

    df = panel_df.dropna(subset=['ca_gdp'] + demo_vars).copy()
    if 'real_bond_10y_diff' not in df.columns:
        print("  real_bond_10y_diff not in panel; skipping rate channel projection")
        return None

    # Controls for S1 (no kaopen — rates don't depend on openness)
    s1_controls = [c for c in ('fiscal_bal_gdp', 'expected_growth', 'nfa_gdp_lag', 'log_rel_opw')
                   if c in df.columns and df[c].notna().sum() > 200]
    s1_vars = demo_vars + s1_controls

    tdf = df.dropna(subset=['real_bond_10y_diff'] + s1_vars)
    if len(tdf) < 100:
        print("  Insufficient observations for rate channel projection")
        return None

    # Estimate S1
    s1 = PanelGLS()
    s1.fit(tdf['real_bond_10y_diff'].values, tdf[s1_vars].values,
           tdf['iso3'].values, tdf['year'].values)
    # beta is ordered like the regressor columns (s1_vars).
    z_coefs = dict(zip(s1_vars, s1.beta))

    # Project using polynomial data
    rows = []
    for iso3 in focus:
        cdf = polys_df[polys_df['iso3'] == iso3]
        if cdf.empty:
            continue
        for year in proj_years:
            yr = cdf[cdf['year'] == year]
            if yr.empty:
                continue
            demo_effect = sum(z_coefs[zv] * yr[zv].iloc[0] for zv in demo_vars)
            rows.append({
                'iso3': iso3,
                'year': year,
                'demo_rate_effect': demo_effect,
                'ca_rate_effect': demo_effect * rate_to_ca,
            })

    proj = pd.DataFrame(rows)
    if not proj.empty:
        tables_dir = OUTPUT_DIR / "tables"
        # to_csv does not create missing parent directories.
        tables_dir.mkdir(parents=True, exist_ok=True)
        proj.to_csv(tables_dir / "rate_channel_projections.csv", index=False)

        pivot = proj.pivot(index='iso3', columns='year', values='demo_rate_effect')
        print("\n  Projected demographic pressure on bond yield differentials (pp):")
        print(pivot.to_string(float_format='%.2f'))

    return proj


if __name__ == "__main__":
    # Load panel and estimate
    panel = pd.read_csv(PROCESSED_DIR / "full_panel.csv")
    baseline, df_base = estimate_baseline_model(panel)
